{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example of how to train a DeepParticleNet (DPN) model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "from imgaug import augmenters as iaa\n",
    "from keras.callbacks import EarlyStopping\n",
    "import numpy as np\n",
    "\n",
    "# Add root directory to the python search path, if it is not already in there.\n",
    "root_dir = os.path.abspath(os.path.join(\"..\"))\n",
    "if root_dir not in sys.path:\n",
    "    sys.path.append(root_dir)\n",
    "    \n",
    "# Import the DeepParticleNet modules.\n",
    "from dpn.config import Config\n",
    "from dpn.dataset import Dataset\n",
    "from dpn.model import Model\n",
    "\n",
    "from external.CLR.clr_callback import CyclicLR"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MpacConfig(Config):\n",
    "    \"\"\"Configuration used for the example training run.\n",
    "\n",
    "    Sets class attributes on a subclass of the project's `Config`. How each\n",
    "    attribute is consumed is defined inside `dpn` and is not visible here.\n",
    "    \"\"\"\n",
    "\n",
    "    # --- General ---\n",
    "    NAME = \"example\"\n",
    "    COMMENT = \"\"\n",
    "\n",
    "    # --- Hardware ---\n",
    "    GPU_COUNT = 1\n",
    "    IMAGES_PER_GPU = 1\n",
    "\n",
    "    # --- Dataset ---\n",
    "    DATASET_PATH = os.path.join(root_dir, \"datasets\")\n",
    "\n",
    "    DATASET_SUBSET_TRAIN = \"train\"\n",
    "    DATASET_SUBSET_VAL = \"val\"\n",
    "\n",
    "    NUMBER_OF_SAMPLES_TRAIN = 400\n",
    "    NUMBER_OF_SAMPLES_VAL = 100\n",
    "\n",
    "    MEAN_PIXEL = [60.2, 60.2, 60.2]\n",
    "\n",
    "    # --- Model ---\n",
    "    BACKBONE = \"resnet50\"\n",
    "    USE_PRETRAINED_WEIGHTS = \"coco\"\n",
    "\n",
    "    # --- Augmentation ---\n",
    "    # http://imgaug.readthedocs.io/en/latest/source/augmenters.html\n",
    "    # SomeOf((0, 2), ...) applies between zero and two of the listed augmenters;\n",
    "    # OneOf(...) picks exactly one of the three fixed rotations.\n",
    "    AUGMENTATION = iaa.SomeOf(\n",
    "        (0, 2),\n",
    "        [\n",
    "            iaa.Fliplr(0.5),\n",
    "            iaa.Flipud(0.5),\n",
    "            iaa.OneOf([iaa.Affine(rotate=angle) for angle in (90, 180, 270)]),\n",
    "            iaa.Multiply((0.8, 1.5)),\n",
    "            iaa.GaussianBlur(sigma=(0.0, 5.0)),\n",
    "        ],\n",
    "    )\n",
    "\n",
    "    # --- Custom callbacks ---\n",
    "    # Triangular cyclic learning rate between base_lr and max_lr. step_size is\n",
    "    # twice the rounded-up ratio of training samples to images per step\n",
    "    # (presumably two epochs' worth of iterations -- TODO confirm against dpn).\n",
    "    CYCLIC_LEARNING_RATE = CyclicLR(\n",
    "        base_lr=0.0005,\n",
    "        max_lr=0.0037,\n",
    "        step_size=2 * np.ceil(NUMBER_OF_SAMPLES_TRAIN / (GPU_COUNT * IMAGES_PER_GPU)),\n",
    "        mode=\"triangular\",\n",
    "    )\n",
    "\n",
    "    # Stop training once val_loss has not improved for 20 epochs.\n",
    "    EARLY_STOPPING = EarlyStopping(\n",
    "        monitor=\"val_loss\",\n",
    "        min_delta=0,\n",
    "        patience=20,\n",
    "        verbose=0,\n",
    "        mode=\"auto\",\n",
    "    )\n",
    "\n",
    "    CUSTOM_CALLBACKS = [CYCLIC_LEARNING_RATE, EARLY_STOPPING]\n",
    "\n",
    "\n",
    "# Instantiate the configuration used by the dataset and model cells.\n",
    "config = MpacConfig()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MpacDataset(Dataset):\n",
    "    \"\"\"Subclass of the project's `Dataset` for data with a single class.\"\"\"\n",
    "\n",
    "    MONOCLASS = \"sphere\"  # Name of the only class in the dataset.\n",
    "\n",
    "\n",
    "# One dataset object for training and one for validation.\n",
    "# NOTE(review): the names passed here (\"training\"/\"validation\") differ from\n",
    "# config.DATASET_SUBSET_TRAIN/VAL (\"train\"/\"val\") -- confirm this is intended.\n",
    "dataset_train = MpacDataset(\n",
    "    config=config,\n",
    "    dataset_name=\"training\",\n",
    ")\n",
    "dataset_val = MpacDataset(\n",
    "    config=config,\n",
    "    dataset_name=\"validation\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logs and model checkpoints are written below <root_dir>/logs.\n",
    "logging_dir = os.path.join(root_dir, \"logs\")\n",
    "\n",
    "# Build the model in training mode from the config created above.\n",
    "model = Model(mode=\"training\", config=config, model_dir=logging_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the training; `history` holds whatever `Model.train` returns.\n",
    "history = model.train(\n",
    "    dataset_train,\n",
    "    dataset_val,\n",
    "    save_best_only=True\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
